import ShiftingWindowSetting as sw
import random
import torch
import numpy as np


class EntropySS(sw.CLLearningAlgo):
    """Replay-based continual-learning algorithm with a per-class memory buffer.

    Every training batch is written into a memory of at most ``mem_size``
    samples, stored per class label as ``(x, task_id)`` pairs. Once the memory
    is full, each incoming sample evicts one stored sample from the class that
    currently holds the most samples. Within that class, a sample is evicted
    with probability proportional to ``1 - d / d_max``, where ``d`` is its L2
    distance (on flattened inputs) to its nearest neighbour in the same class.
    Samples that sit close to another sample are therefore evicted first.
    A random mini-batch drawn from the memory provides the regularisation loss.
    """

    # label -> list of (input tensor, task id) pairs. Assigned per instance in
    # __init__ so that separate instances never share one buffer.
    per_class_mem: dict

    def __init__(self, args, mem_size=1000, replay_batch_size=10):
        """Create the algorithm.

        Args:
            args: forwarded unchanged to ``sw.CLLearningAlgo.__init__``.
            mem_size: maximum number of samples kept in the replay memory.
            replay_batch_size: number of memory samples drawn per replay step.
        """
        # Must be an instance attribute. A class-level dict would be shared by
        # every EntropySS object (e.g. repeated runs in one process) while the
        # counters below are per instance, desynchronising the two.
        self.per_class_mem = {}
        super().__init__(args=args)
        self.mem_size = mem_size
        self.remaining_space = mem_size
        # Counts samples inserted while the buffer was filling. It is not
        # incremented in the replacement phase.
        self.seen_count = 0
        # `next` and `w` are not used in this class; they are kept in case code
        # elsewhere reads them.
        self.next = 0
        self.w = 0
        self.full = False
        self.replay_batch_size = replay_batch_size

    def calc_reg_loss_term(self):
        """Return the replay loss on a random mini-batch drawn from the memory.

        Until the memory holds at least ``replay_batch_size`` samples (or when
        that size is not positive), a zero tensor with ``requires_grad=True``
        is returned instead.
        """
        mem = [(x, y, t) for y, samples in self.per_class_mem.items() for x, t in samples]
        if self.replay_batch_size <= 0 or len(mem) < self.replay_batch_size:
            return torch.zeros(1, device=self.device, requires_grad=True)

        X, Y, t = zip(*random.sample(mem, self.replay_batch_size))
        X = torch.stack(X, dim=0).to(self.device)
        Y = torch.tensor(Y).to(self.device)
        # Each replayed sample is evaluated with the null classes of the task it
        # was stored under. Presumably calc_null_classes (defined outside this
        # file) returns the output units to mask for that task; verify there.
        per_point_null_classes = [self.calc_null_classes(task, self.task_stream.classes) for task in t]
        return self.loss_fn(sw.calc_multi_head_model_output(self.model, X, per_point_null_classes), Y)

    def after_optimiser_step(self):
        """Hook run after each optimiser step: store the current batch in memory."""
        self._update_memory()

    def _update_memory(self):
        """Insert every sample of ``self.batch`` into the replay memory.

        Samples first fill the free slots. Any remaining samples each replace
        one sample of the currently largest class (see ``_eviction_index``).
        Assumes ``self.batch`` is ``(inputs, labels)`` with labels convertible
        via ``.item()``; it is set outside this class. TODO confirm.
        """
        if self.mem_size <= 0:
            # Nothing can ever be stored, and the eviction path below would
            # fail on an empty memory.
            return

        xs, ys = self.batch[0], self.batch[1]
        # Detach and copy each sample so the memory neither keeps whole batch
        # tensors alive (a row of xs is a view) nor carries autograd history.
        batch = [(xs[i].detach().clone(), ys[i].item(), self.task_id) for i in range(xs.shape[0])]

        if not self.full:
            n_fill = min(len(batch), self.remaining_space)
            for x, y, t in batch[:n_fill]:
                self.per_class_mem.setdefault(y, []).append((x, t))
            self.remaining_space -= n_fill
            self.seen_count += n_fill
            if self.remaining_space == 0:
                self.full = True
            batch = batch[n_fill:]

        for x, y, t in batch:
            victims = self.per_class_mem[self._majority_class()]
            del victims[self._eviction_index(victims)]
            self.per_class_mem.setdefault(y, []).append((x, t))

    def _majority_class(self):
        """Return the label with the most stored samples (the later label wins ties)."""
        majority_class, max_n = None, -1
        for label, samples in self.per_class_mem.items():
            if len(samples) >= max_n:
                majority_class, max_n = label, len(samples)
        return majority_class

    @staticmethod
    def _eviction_index(samples):
        """Pick the index of the sample to evict from one class's ``(x, t)`` list.

        Each sample is weighted by ``1 - d / d_max``, where ``d`` is its
        distance to its nearest neighbour. Sampling falls back to uniform when
        the weights are degenerate (e.g. all nearest-neighbour distances equal,
        or all samples identical), a case that previously produced NaN
        probabilities.
        """
        n = len(samples)
        if n == 1:
            return 0
        with torch.no_grad():
            flat = torch.stack([x.reshape(-1) for x, _ in samples])
            dists = torch.cdist(flat, flat, p=2)
            # Exclude self-distances. This is done in place, so it works on any
            # device, unlike adding a CPU-built torch.eye.
            dists.fill_diagonal_(float("inf"))
            nn_dist = dists.min(dim=1).values
            weights = 1 - nn_dist / nn_dist.max()
        weights = weights.double().cpu().numpy()
        total = weights.sum()
        if not np.isfinite(total) or total <= 0:
            return int(np.random.choice(n))
        return int(np.random.choice(n, p=weights / total))










